{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:39:08.101994Z",
     "start_time": "2019-11-06T14:39:07.152077Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div class=\"bk-root\">\n",
       "        <a href=\"https://bokeh.pydata.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n",
       "        <span id=\"1001\">Loading BokehJS ...</span>\n",
       "    </div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/javascript": [
       "\n",
       "(function(root) {\n",
       "  function now() {\n",
       "    return new Date();\n",
       "  }\n",
       "\n",
       "  var force = true;\n",
       "\n",
       "  if (typeof (root._bokeh_onload_callbacks) === \"undefined\" || force === true) {\n",
       "    root._bokeh_onload_callbacks = [];\n",
       "    root._bokeh_is_loading = undefined;\n",
       "  }\n",
       "\n",
       "  var JS_MIME_TYPE = 'application/javascript';\n",
       "  var HTML_MIME_TYPE = 'text/html';\n",
       "  var EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n",
       "  var CLASS_NAME = 'output_bokeh rendered_html';\n",
       "\n",
       "  /**\n",
       "   * Render data to the DOM node\n",
       "   */\n",
       "  function render(props, node) {\n",
       "    var script = document.createElement(\"script\");\n",
       "    node.appendChild(script);\n",
       "  }\n",
       "\n",
       "  /**\n",
       "   * Handle when an output is cleared or removed\n",
       "   */\n",
       "  function handleClearOutput(event, handle) {\n",
       "    var cell = handle.cell;\n",
       "\n",
       "    var id = cell.output_area._bokeh_element_id;\n",
       "    var server_id = cell.output_area._bokeh_server_id;\n",
       "    // Clean up Bokeh references\n",
       "    if (id != null && id in Bokeh.index) {\n",
       "      Bokeh.index[id].model.document.clear();\n",
       "      delete Bokeh.index[id];\n",
       "    }\n",
       "\n",
       "    if (server_id !== undefined) {\n",
       "      // Clean up Bokeh references\n",
       "      var cmd = \"from bokeh.io.state import curstate; print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n",
       "      cell.notebook.kernel.execute(cmd, {\n",
       "        iopub: {\n",
       "          output: function(msg) {\n",
       "            var id = msg.content.text.trim();\n",
       "            if (id in Bokeh.index) {\n",
       "              Bokeh.index[id].model.document.clear();\n",
       "              delete Bokeh.index[id];\n",
       "            }\n",
       "          }\n",
       "        }\n",
       "      });\n",
       "      // Destroy server and session\n",
       "      var cmd = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n",
       "      cell.notebook.kernel.execute(cmd);\n",
       "    }\n",
       "  }\n",
       "\n",
       "  /**\n",
       "   * Handle when a new output is added\n",
       "   */\n",
       "  function handleAddOutput(event, handle) {\n",
       "    var output_area = handle.output_area;\n",
       "    var output = handle.output;\n",
       "\n",
       "    // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n",
       "    if ((output.output_type != \"display_data\") || (!output.data.hasOwnProperty(EXEC_MIME_TYPE))) {\n",
       "      return\n",
       "    }\n",
       "\n",
       "    var toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n",
       "\n",
       "    if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n",
       "      toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n",
       "      // store reference to embed id on output_area\n",
       "      output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n",
       "    }\n",
       "    if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n",
       "      var bk_div = document.createElement(\"div\");\n",
       "      bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n",
       "      var script_attrs = bk_div.children[0].attributes;\n",
       "      for (var i = 0; i < script_attrs.length; i++) {\n",
       "        toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n",
       "      }\n",
       "      // store reference to server id on output_area\n",
       "      output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n",
       "    }\n",
       "  }\n",
       "\n",
       "  function register_renderer(events, OutputArea) {\n",
       "\n",
       "    function append_mime(data, metadata, element) {\n",
       "      // create a DOM node to render to\n",
       "      var toinsert = this.create_output_subarea(\n",
       "        metadata,\n",
       "        CLASS_NAME,\n",
       "        EXEC_MIME_TYPE\n",
       "      );\n",
       "      this.keyboard_manager.register_events(toinsert);\n",
       "      // Render to node\n",
       "      var props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n",
       "      render(props, toinsert[toinsert.length - 1]);\n",
       "      element.append(toinsert);\n",
       "      return toinsert\n",
       "    }\n",
       "\n",
       "    /* Handle when an output is cleared or removed */\n",
       "    events.on('clear_output.CodeCell', handleClearOutput);\n",
       "    events.on('delete.Cell', handleClearOutput);\n",
       "\n",
       "    /* Handle when a new output is added */\n",
       "    events.on('output_added.OutputArea', handleAddOutput);\n",
       "\n",
       "    /**\n",
       "     * Register the mime type and append_mime function with output_area\n",
       "     */\n",
       "    OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n",
       "      /* Is output safe? */\n",
       "      safe: true,\n",
       "      /* Index of renderer in `output_area.display_order` */\n",
       "      index: 0\n",
       "    });\n",
       "  }\n",
       "\n",
       "  // register the mime type if in Jupyter Notebook environment and previously unregistered\n",
       "  if (root.Jupyter !== undefined) {\n",
       "    var events = require('base/js/events');\n",
       "    var OutputArea = require('notebook/js/outputarea').OutputArea;\n",
       "\n",
       "    if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n",
       "      register_renderer(events, OutputArea);\n",
       "    }\n",
       "  }\n",
       "\n",
       "  \n",
       "  if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n",
       "    root._bokeh_timeout = Date.now() + 5000;\n",
       "    root._bokeh_failed_load = false;\n",
       "  }\n",
       "\n",
       "  var NB_LOAD_WARNING = {'data': {'text/html':\n",
       "     \"<div style='background-color: #fdd'>\\n\"+\n",
       "     \"<p>\\n\"+\n",
       "     \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n",
       "     \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n",
       "     \"</p>\\n\"+\n",
       "     \"<ul>\\n\"+\n",
       "     \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n",
       "     \"<li>use INLINE resources instead, as so:</li>\\n\"+\n",
       "     \"</ul>\\n\"+\n",
       "     \"<code>\\n\"+\n",
       "     \"from bokeh.resources import INLINE\\n\"+\n",
       "     \"output_notebook(resources=INLINE)\\n\"+\n",
       "     \"</code>\\n\"+\n",
       "     \"</div>\"}};\n",
       "\n",
       "  function display_loaded() {\n",
       "    var el = document.getElementById(\"1001\");\n",
       "    if (el != null) {\n",
       "      el.textContent = \"BokehJS is loading...\";\n",
       "    }\n",
       "    if (root.Bokeh !== undefined) {\n",
       "      if (el != null) {\n",
       "        el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n",
       "      }\n",
       "    } else if (Date.now() < root._bokeh_timeout) {\n",
       "      setTimeout(display_loaded, 100)\n",
       "    }\n",
       "  }\n",
       "\n",
       "\n",
       "  function run_callbacks() {\n",
       "    try {\n",
       "      root._bokeh_onload_callbacks.forEach(function(callback) { callback() });\n",
       "    }\n",
       "    finally {\n",
       "      delete root._bokeh_onload_callbacks\n",
       "    }\n",
       "    console.info(\"Bokeh: all callbacks have finished\");\n",
       "  }\n",
       "\n",
       "  function load_libs(js_urls, callback) {\n",
       "    root._bokeh_onload_callbacks.push(callback);\n",
       "    if (root._bokeh_is_loading > 0) {\n",
       "      console.log(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n",
       "      return null;\n",
       "    }\n",
       "    if (js_urls == null || js_urls.length === 0) {\n",
       "      run_callbacks();\n",
       "      return null;\n",
       "    }\n",
       "    console.log(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n",
       "    root._bokeh_is_loading = js_urls.length;\n",
       "    for (var i = 0; i < js_urls.length; i++) {\n",
       "      var url = js_urls[i];\n",
       "      var s = document.createElement('script');\n",
       "      s.src = url;\n",
       "      s.async = false;\n",
       "      s.onreadystatechange = s.onload = function() {\n",
       "        root._bokeh_is_loading--;\n",
       "        if (root._bokeh_is_loading === 0) {\n",
       "          console.log(\"Bokeh: all BokehJS libraries loaded\");\n",
       "          run_callbacks()\n",
       "        }\n",
       "      };\n",
       "      s.onerror = function() {\n",
       "        console.warn(\"failed to load library \" + url);\n",
       "      };\n",
       "      console.log(\"Bokeh: injecting script tag for BokehJS library: \", url);\n",
       "      document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
       "    }\n",
       "  };var element = document.getElementById(\"1001\");\n",
       "  if (element == null) {\n",
       "    console.log(\"Bokeh: ERROR: autoload.js configured with elementid '1001' but no matching script tag was found. \")\n",
       "    return false;\n",
       "  }\n",
       "\n",
       "  var js_urls = [\"https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-gl-1.0.0.min.js\"];\n",
       "\n",
       "  var inline_js = [\n",
       "    function(Bokeh) {\n",
       "      Bokeh.set_log_level(\"info\");\n",
       "    },\n",
       "    \n",
       "    function(Bokeh) {\n",
       "      \n",
       "    },\n",
       "    function(Bokeh) {\n",
       "      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.css\");\n",
       "      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.css\");\n",
       "      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.css\");\n",
       "      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.css\");\n",
       "      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.css\");\n",
       "      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.css\");\n",
       "    }\n",
       "  ];\n",
       "\n",
       "  function run_inline_js() {\n",
       "    \n",
       "    if ((root.Bokeh !== undefined) || (force === true)) {\n",
       "      for (var i = 0; i < inline_js.length; i++) {\n",
       "        inline_js[i].call(root, root.Bokeh);\n",
       "      }if (force === true) {\n",
       "        display_loaded();\n",
       "      }} else if (Date.now() < root._bokeh_timeout) {\n",
       "      setTimeout(run_inline_js, 100);\n",
       "    } else if (!root._bokeh_failed_load) {\n",
       "      console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n",
       "      root._bokeh_failed_load = true;\n",
       "    } else if (force !== true) {\n",
       "      var cell = $(document.getElementById(\"1001\")).parents('.cell').data().cell;\n",
       "      cell.output_area.append_execute_result(NB_LOAD_WARNING)\n",
       "    }\n",
       "\n",
       "  }\n",
       "\n",
       "  if (root._bokeh_is_loading === 0) {\n",
       "    console.log(\"Bokeh: BokehJS loaded, going straight to plotting\");\n",
       "    run_inline_js();\n",
       "  } else {\n",
       "    load_libs(js_urls, function() {\n",
       "      console.log(\"Bokeh: BokehJS plotting callback run at\", now());\n",
       "      run_inline_js();\n",
       "    });\n",
       "  }\n",
       "}(window));"
      ],
      "application/vnd.bokehjs_load.v0+json": "\n(function(root) {\n  function now() {\n    return new Date();\n  }\n\n  var force = true;\n\n  if (typeof (root._bokeh_onload_callbacks) === \"undefined\" || force === true) {\n    root._bokeh_onload_callbacks = [];\n    root._bokeh_is_loading = undefined;\n  }\n\n  \n\n  \n  if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n    root._bokeh_timeout = Date.now() + 5000;\n    root._bokeh_failed_load = false;\n  }\n\n  var NB_LOAD_WARNING = {'data': {'text/html':\n     \"<div style='background-color: #fdd'>\\n\"+\n     \"<p>\\n\"+\n     \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n     \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n     \"</p>\\n\"+\n     \"<ul>\\n\"+\n     \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n     \"<li>use INLINE resources instead, as so:</li>\\n\"+\n     \"</ul>\\n\"+\n     \"<code>\\n\"+\n     \"from bokeh.resources import INLINE\\n\"+\n     \"output_notebook(resources=INLINE)\\n\"+\n     \"</code>\\n\"+\n     \"</div>\"}};\n\n  function display_loaded() {\n    var el = document.getElementById(\"1001\");\n    if (el != null) {\n      el.textContent = \"BokehJS is loading...\";\n    }\n    if (root.Bokeh !== undefined) {\n      if (el != null) {\n        el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n      }\n    } else if (Date.now() < root._bokeh_timeout) {\n      setTimeout(display_loaded, 100)\n    }\n  }\n\n\n  function run_callbacks() {\n    try {\n      root._bokeh_onload_callbacks.forEach(function(callback) { callback() });\n    }\n    finally {\n      delete root._bokeh_onload_callbacks\n    }\n    console.info(\"Bokeh: all callbacks have finished\");\n  }\n\n  function load_libs(js_urls, callback) {\n    root._bokeh_onload_callbacks.push(callback);\n    if (root._bokeh_is_loading > 0) {\n      
console.log(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n      return null;\n    }\n    if (js_urls == null || js_urls.length === 0) {\n      run_callbacks();\n      return null;\n    }\n    console.log(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n    root._bokeh_is_loading = js_urls.length;\n    for (var i = 0; i < js_urls.length; i++) {\n      var url = js_urls[i];\n      var s = document.createElement('script');\n      s.src = url;\n      s.async = false;\n      s.onreadystatechange = s.onload = function() {\n        root._bokeh_is_loading--;\n        if (root._bokeh_is_loading === 0) {\n          console.log(\"Bokeh: all BokehJS libraries loaded\");\n          run_callbacks()\n        }\n      };\n      s.onerror = function() {\n        console.warn(\"failed to load library \" + url);\n      };\n      console.log(\"Bokeh: injecting script tag for BokehJS library: \", url);\n      document.getElementsByTagName(\"head\")[0].appendChild(s);\n    }\n  };var element = document.getElementById(\"1001\");\n  if (element == null) {\n    console.log(\"Bokeh: ERROR: autoload.js configured with elementid '1001' but no matching script tag was found. 
\")\n    return false;\n  }\n\n  var js_urls = [\"https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-gl-1.0.0.min.js\"];\n\n  var inline_js = [\n    function(Bokeh) {\n      Bokeh.set_log_level(\"info\");\n    },\n    \n    function(Bokeh) {\n      \n    },\n    function(Bokeh) {\n      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.css\");\n      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.css\");\n      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.css\");\n      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.css\");\n      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.css\");\n      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.css\");\n    }\n  ];\n\n  function run_inline_js() {\n    \n    if ((root.Bokeh !== undefined) || (force === true)) {\n      for (var i = 0; i < inline_js.length; i++) {\n        inline_js[i].call(root, root.Bokeh);\n      }if (force === true) {\n        display_loaded();\n      }} else if (Date.now() < root._bokeh_timeout) {\n      setTimeout(run_inline_js, 100);\n    } else if (!root._bokeh_failed_load) {\n      console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n      root._bokeh_failed_load = true;\n    } else if (force !== true) {\n      var cell = $(document.getElementById(\"1001\")).parents('.cell').data().cell;\n      cell.output_area.append_execute_result(NB_LOAD_WARNING)\n    }\n\n  }\n\n  if (root._bokeh_is_loading === 0) {\n    console.log(\"Bokeh: BokehJS loaded, going straight to plotting\");\n    run_inline_js();\n  } else {\n    load_libs(js_urls, function() {\n   
   console.log(\"Bokeh: BokehJS plotting callback run at\", now());\n      run_inline_js();\n    });\n  }\n}(window));"
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision.datasets as dsets\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from bokeh.io import show, output_notebook\n",
    "from bokeh.plotting import figure, gridplot\n",
    "from bokeh.models import LinearAxis, Range1d\n",
    "output_notebook()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPU\n",
    "指定使用的 GPU 编号。  \n",
    "`watch -n 1 nvidia-smi` 实时查看 GPU 的运行状态。 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:39:08.338393Z",
     "start_time": "2019-11-06T14:39:08.103740Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.cuda.set_device(\"cuda:3\")\n",
    "torch.cuda.current_device()\n",
    "# device = torch.device(\"cuda:5\")\n",
    "# xxx.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-02T15:34:18.036850Z",
     "start_time": "2019-11-02T15:34:18.032937Z"
    }
   },
   "source": [
    "### Data\n",
    "通过`torchvision.datasets`下载`MNIST`数据。  \n",
    "训练集：`train=True`  \n",
    "测试集：`train=False`  \n",
    "常用的还有`torchvision.datasets.ImageFolder()`，按文件夹取图片。  \n",
    "\n",
    "`torchvision.transforms`可以对图片做处理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:39:08.491812Z",
     "start_time": "2019-11-06T14:39:08.340801Z"
    }
   },
   "outputs": [],
   "source": [
    "train_dataset = dsets.MNIST(root='../dataset', train=True, transform=transforms.ToTensor(), download=True)\n",
    "test_dataset = dsets.MNIST(root='../dataset', train=False, transform=transforms.ToTensor(), download=True)\n",
    "\n",
    "batch_size = 100\n",
    "train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model\n",
     "用双向 LSTM 构造 RNN。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里定义的 LSTM 满足 `batch_first = True`, 即输入和输出的 Tensor 的维度中，`batch_size` 要放在第一个。  \n",
    "输入的 `Tensor(batch_size, seq_len, dim)`  \n",
     "否则（即 PyTorch 默认的 `batch_first = False`），输入为 `Tensor(seq_len, batch_size, dim)`；两种布局之间可以通过 `X = X.transpose(0,1)`，即交换前两维来转换。  \n",
    "\n",
    "作为单元组成部分的 Hidden 和 Cell 向量也有对应的维度，为 `Tensor(layer_size, batch_size, dim_hidden)`.\n",
    "\n",
    "经过 LSTM 输出的结果，只要保留最后一个时刻的结果，即在 `seq_len` 的一栏填 `-1`. 使其成为 `Tensor(batch_size, dim)`，再通过线性网络输出。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里只对单向的 LSTM 做一个小改动，即加入参数 `bidirectional=True`. 因此输出的 Tensor dim 会加倍。\n",
    "LSTM 单元会加倍，因此 Hidden 和 Cell 为 `Tensor(layer_size*2, batch_size, dim_hidden)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:39:08.555613Z",
     "start_time": "2019-11-06T14:39:08.494270Z"
    }
   },
   "outputs": [],
   "source": [
    "class BiRNN(nn.Module):\n",
    "    def __init__(self, dim_in, dim_hid, dim_out, layers_size):\n",
    "        super().__init__()\n",
    "        self.dim_hid = dim_hid\n",
    "        self.layers_size = layers_size\n",
    "        self.lstm = nn.LSTM(dim_in, dim_hid, layers_size, batch_first=True, bidirectional=True)\n",
    "        self.fc = nn.Linear(dim_hid*2, dim_out)\n",
    "    def forward(self, x, batch_size):\n",
    "        h0 = torch.zeros(self.layers_size*2, batch_size, self.dim_hid).cuda()\n",
    "        c0 = torch.zeros(self.layers_size*2, batch_size, self.dim_hid).cuda()\n",
    "        x, _ = self.lstm(x, (h0, c0))\n",
    "        return self.fc(x[:, -1, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:39:14.525525Z",
     "start_time": "2019-11-06T14:39:08.557285Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BiRNN(\n",
       "  (lstm): LSTM(28, 128, num_layers=2, batch_first=True, bidirectional=True)\n",
       "  (fc): Linear(in_features=256, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lrate = 0.01\n",
    "epochs = 2\n",
    "sequence_length = 28\n",
    "dim_in = 28\n",
    "dim_hid = 128\n",
    "layers_size = 2\n",
    "num_classes = 10\n",
    "\n",
    "model = BiRNN(dim_in, dim_hid, num_classes, layers_size).cuda()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optim = torch.optim.Adam(model.parameters(), lr=lrate)\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "注意：每次反向传播之前都需要用 `optim.zero_grad()` 将参数的梯度归零，因为 PyTorch 会把新计算的梯度累加到已有的 `grad` 上。  \n",
     "`optim.step()` 则在 `loss.backward()` 计算出每个参数的 `grad` 之后，根据梯度更新每个参数的数值。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在每次训练中都用`train_loader`中的一个`batch`作为训练数据。  \n",
     "每个 `batch` 在实际使用之前，都先通过 `Tensor.cuda()` 移入 `GPU`，再进行计算。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:39:41.906708Z",
     "start_time": "2019-11-06T14:39:14.528643Z"
    },
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "  <div class=\"bk-root\" id=\"fc1674fa-dd66-46c6-9b0c-30533d00480e\"></div>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/javascript": [
       "(function(root) {\n",
       "  function embed_document(root) {\n",
       "    \n",
       "  var docs_json = {\"790fbe8e-f23b-47f7-bba0-6d1da9d476f2\":{\"roots\":{\"references\":[{\"attributes\":{\"below\":[{\"id\":\"1011\",\"type\":\"LinearAxis\"}],\"left\":[{\"id\":\"1016\",\"type\":\"LinearAxis\"}],\"renderers\":[{\"id\":\"1011\",\"type\":\"LinearAxis\"},{\"id\":\"1015\",\"type\":\"Grid\"},{\"id\":\"1016\",\"type\":\"LinearAxis\"},{\"id\":\"1020\",\"type\":\"Grid\"},{\"id\":\"1029\",\"type\":\"BoxAnnotation\"},{\"id\":\"1039\",\"type\":\"GlyphRenderer\"}],\"title\":{\"id\":\"1042\",\"type\":\"Title\"},\"toolbar\":{\"id\":\"1027\",\"type\":\"Toolbar\"},\"x_range\":{\"id\":\"1003\",\"type\":\"DataRange1d\"},\"x_scale\":{\"id\":\"1007\",\"type\":\"LinearScale\"},\"y_range\":{\"id\":\"1005\",\"type\":\"DataRange1d\"},\"y_scale\":{\"id\":\"1009\",\"type\":\"LinearScale\"}},\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},{\"attributes\":{\"bottom_units\":\"screen\",\"fill_alpha\":{\"value\":0.5},\"fill_color\":{\"value\":\"lightgrey\"},\"left_units\":\"screen\",\"level\":\"overlay\",\"line_alpha\":{\"value\":1.0},\"line_color\":{\"value\":\"black\"},\"line_dash\":[4,4],\"line_width\":{\"value\":2},\"plot\":null,\"render_mode\":\"css\",\"right_units\":\"screen\",\"top_units\":\"screen\"},\"id\":\"1029\",\"type\":\"BoxAnnotation\"},{\"attributes\":{},\"id\":\"1009\",\"type\":\"LinearScale\"},{\"attributes\":{\"data_source\":{\"id\":\"1036\",\"type\":\"ColumnDataSource\"},\"glyph\":{\"id\":\"1037\",\"type\":\"Line\"},\"hover_glyph\":null,\"muted_glyph\":null,\"nonselection_glyph\":{\"id\":\"1038\",\"type\":\"Line\"},\"selection_glyph\":null,\"view\":{\"id\":\"1040\",\"type\":\"CDSView\"}},\"id\":\"1039\",\"type\":\"GlyphRenderer\"},{\"attributes\":{\"formatter\":{\"id\":\"1045\",\"type\":\"BasicTickFormatter\"},\"plot\":{\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},\"ticker\":{\"id\":\"1012\",\"type\":\"BasicTicker\"}},\"id\":\"1011\",\"type\":\"LinearAxis\"},{\"attributes\":{\"source\":{\"id\":\"1036\",\"type\":\"ColumnDataSource
\"}},\"id\":\"1040\",\"type\":\"CDSView\"},{\"attributes\":{},\"id\":\"1012\",\"type\":\"BasicTicker\"},{\"attributes\":{\"plot\":null,\"text\":\"\"},\"id\":\"1042\",\"type\":\"Title\"},{\"attributes\":{\"plot\":{\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},\"ticker\":{\"id\":\"1012\",\"type\":\"BasicTicker\"}},\"id\":\"1015\",\"type\":\"Grid\"},{\"attributes\":{},\"id\":\"1043\",\"type\":\"BasicTickFormatter\"},{\"attributes\":{\"formatter\":{\"id\":\"1043\",\"type\":\"BasicTickFormatter\"},\"plot\":{\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},\"ticker\":{\"id\":\"1017\",\"type\":\"BasicTicker\"}},\"id\":\"1016\",\"type\":\"LinearAxis\"},{\"attributes\":{},\"id\":\"1045\",\"type\":\"BasicTickFormatter\"},{\"attributes\":{},\"id\":\"1017\",\"type\":\"BasicTicker\"},{\"attributes\":{},\"id\":\"1048\",\"type\":\"Selection\"},{\"attributes\":{\"dimension\":1,\"plot\":{\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},\"ticker\":{\"id\":\"1017\",\"type\":\"BasicTicker\"}},\"id\":\"1020\",\"type\":\"Grid\"},{\"attributes\":{},\"id\":\"1049\",\"type\":\"UnionRenderers\"},{\"attributes\":{\"line_color\":\"#1f77b4\",\"x\":{\"field\":\"x\"},\"y\":{\"field\":\"y\"}},\"id\":\"1037\",\"type\":\"Line\"},{\"attributes\":{\"callback\":null,\"data\":{\"x\":[0,1,2,3,4,5,6,7,8,9,10,11],\"y\":[2.2956762313842773,0.3041001558303833,0.19128233194351196,0.14736545085906982,0.12811286747455597,0.17608138918876648,0.12269347906112671,0.1628677248954773,0.08360164612531662,0.021787257865071297,0.18575935065746307,0.15738791227340698]},\"selected\":{\"id\":\"1048\",\"type\":\"Selection\"},\"selection_policy\":{\"id\":\"1049\",\"type\":\"UnionRenderers\"}},\"id\":\"1036\",\"type\":\"ColumnDataSource\"},{\"attributes\":{\"callback\":null},\"id\":\"1005\",\"type\":\"DataRange1d\"},{\"attributes\":{},\"id\":\"1021\",\"type\":\"PanTool\"},{\"attributes\":{},\"id\":\"1022\",\"type\":\"WheelZoomTool\"},{\"attributes\":{\"overlay\":{\"id\":\"1029\",\"type\":\"BoxAn
notation\"}},\"id\":\"1023\",\"type\":\"BoxZoomTool\"},{\"attributes\":{\"callback\":null},\"id\":\"1003\",\"type\":\"DataRange1d\"},{\"attributes\":{},\"id\":\"1024\",\"type\":\"SaveTool\"},{\"attributes\":{},\"id\":\"1025\",\"type\":\"ResetTool\"},{\"attributes\":{},\"id\":\"1026\",\"type\":\"HelpTool\"},{\"attributes\":{\"active_drag\":\"auto\",\"active_inspect\":\"auto\",\"active_multi\":null,\"active_scroll\":\"auto\",\"active_tap\":\"auto\",\"tools\":[{\"id\":\"1021\",\"type\":\"PanTool\"},{\"id\":\"1022\",\"type\":\"WheelZoomTool\"},{\"id\":\"1023\",\"type\":\"BoxZoomTool\"},{\"id\":\"1024\",\"type\":\"SaveTool\"},{\"id\":\"1025\",\"type\":\"ResetTool\"},{\"id\":\"1026\",\"type\":\"HelpTool\"}]},\"id\":\"1027\",\"type\":\"Toolbar\"},{\"attributes\":{\"line_alpha\":0.1,\"line_color\":\"#1f77b4\",\"x\":{\"field\":\"x\"},\"y\":{\"field\":\"y\"}},\"id\":\"1038\",\"type\":\"Line\"},{\"attributes\":{},\"id\":\"1007\",\"type\":\"LinearScale\"}],\"root_ids\":[\"1002\"]},\"title\":\"Bokeh Application\",\"version\":\"1.0.0\"}};\n",
       "  var render_items = [{\"docid\":\"790fbe8e-f23b-47f7-bba0-6d1da9d476f2\",\"roots\":{\"1002\":\"fc1674fa-dd66-46c6-9b0c-30533d00480e\"}}];\n",
       "  root.Bokeh.embed.embed_items_notebook(docs_json, render_items);\n",
       "\n",
       "  }\n",
       "  if (root.Bokeh !== undefined) {\n",
       "    embed_document(root);\n",
       "  } else {\n",
       "    var attempts = 0;\n",
       "    var timer = setInterval(function(root) {\n",
       "      if (root.Bokeh !== undefined) {\n",
       "        embed_document(root);\n",
       "        clearInterval(timer);\n",
       "      }\n",
       "      attempts++;\n",
       "      if (attempts > 100) {\n",
       "        console.log(\"Bokeh: ERROR: Unable to run BokehJS code because BokehJS library is missing\");\n",
       "        clearInterval(timer);\n",
       "      }\n",
       "    }, 10, root)\n",
       "  }\n",
       "})(window);"
      ],
      "application/vnd.bokehjs_exec.v0+json": ""
     },
     "metadata": {
      "application/vnd.bokehjs_exec.v0+json": {
       "id": "1002"
      }
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "result = []\n",
    "for e in range(epochs):\n",
    "    for i, (inputs, targets) in enumerate(train_loader):\n",
    "        inputs = inputs.squeeze(1).cuda()\n",
    "        targets = targets.cuda()\n",
    "        optim.zero_grad() \n",
    "        outputs = model(inputs, batch_size)\n",
    "        loss = criterion(outputs, targets)\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        if i % 100 == 0:\n",
    "            result.append(float(loss))\n",
    "fig = figure()\n",
    "fig.line(range(len(result)), result)\n",
    "show(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "`re = torch.max(Tensor, dim)` 返回一个二元组（tuple）：`re[0]` 为沿 `dim` 取得的最大值组成的 `Tensor`，`re[1]` 为最大值对应的 `index`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:40:05.193300Z",
     "start_time": "2019-11-06T14:39:41.908458Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the model on the 10000 test images: 97.57 %\n"
     ]
    }
   ],
   "source": [
    "correct = 0\n",
    "total = 0\n",
    "wrong_count = 0\n",
    "wrong_classify = []\n",
    "for i, (inputs, targets) in enumerate(test_loader):\n",
    "    inputs = inputs.squeeze(1).cuda()\n",
    "    outputs = model(inputs, 1)\n",
    "    _, preds = torch.max(outputs.data, 1)\n",
    "    preds = preds.cpu()\n",
    "    total += len(outputs)\n",
    "    correct += (preds == targets).sum()\n",
    "    if wrong_count < 5 and preds != targets:\n",
    "        wrong_classify.append([inputs.data, preds, targets])\n",
    "        wrong_count += 1\n",
    "accuracy = 100 * correct.double() / total\n",
    "print('Accuracy of the model on the 10000 test images: %.2f %%' % (accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:40:05.933596Z",
     "start_time": "2019-11-06T14:40:05.197015Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA2oAAAC/CAYAAACPMC8KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3XucXfO9//H3J5MbCRq5iYhcJKmUEhXiVo2iVZe6lcPRSH+CqKbRlqIOpT11aNFWxaVBJO7qVhxaNCgV1SRoiBAhQYhEIpGbXGbm+/tjr/RM57Mms/bs21p7Xs/HYx6T+ey11/qumfesrO9ee33GQggCAAAAAKRHm0oPAAAAAADw75ioAQAAAEDKMFEDAAAAgJRhogYAAAAAKcNEDQAAAABShokaAAAAAKQME7VWysyeMbNTKz0OIB/kFlljZpPM7BeVHgeQFMdZZE01H2dbzUTNzOab2UFl3mYws9Vmtir6uKmc2y+EmX3fzOaZ2Qozm25m+1V6TK1RhXI71MxmmNma6PPQcm6/GMzsluj3b2Clx9LaVCizR5jZa9FxdqqZfaGc2y8GM7s4ymxZv3fg/CBfZtbdzO40s+VmtszM7qj0mFqbcmfWzLqZ2fNmtjT6ub9gZvuWa/uFMLOTGvyerYrObYKZ7V7psSXRaiZqzTGztiVa9a4hhM7RR0leoSr22M1suKTLJX1L0laSbpb0oJnVFHM7KFwJfvbtJT0k6XZJXSRNlvRQVC+qUv3ORS8q7FCKdaNwJcjsIEl3SDpD0uckPSLp4VLkq4SZ3UG54+3CUqwfheH8wHlA0keS+krqIenKEmwDBSjBz32VpFMkdVfu3OCXkh7JwnE2hHBHg9+zzpLOlPSOpJeKuZ1SaRUTNTO7TdL2yoVqlZmda2b9ohn1aDN7T9JTZjbCzBY0eu6/XrUwszZmdr6ZvR29qvAHM9u6RGMOZjbOzN4xsyVmdoWZtYke+070ysZvzOwTSZdE9VPMbHb0CtfjZta3wfoONrM3zOxTMxsvyTax+X6SZoUQZoQQgqRbJXVT7oCMMqlQbkdIaivptyGEdSGE3ymXla8mHHMlc7vxAH+NpLFJxoviqlBmvy7puRDC30IItcqdQPSW9JWEY55vZj8xs9ejDN5iZh2jx0aY2QIzO8/MPpJ0S1Q/3Mxesdwry1PNbJcG69vNzF4ys5Vmdo+kjgmGMV7SeZLWJxkziofzg/yOs2b2NUl9JP04hPBpCGFDCOHlUuwn4lUisyGEtSGEN0MI9crlo065CVuijKfkOLvRKEm3Rue3qdcqJmohhJGS3pN0RDSj/lWDh78iaYhy/9k3Z5yko6LnbCtpmaRrNz5oZjPN7D8bPedZM/vIzB4ws355Dv1oScMkfUnSkcq9mrHRcOVeEegh6VIzO0rSBZKOUe4Vj+ck3RWNq5uk+yVdqNyE621J/7pkbWbbR78I20elP0mqMbPhlruKdoqkV5R7BQ1lUqHc7iRpZqMD2MyonlSlcitJP5T0bAhhZh7jRZFUKLOmfz+x3Pj1znkM/aRoXDtIGqxc5jbaRrmTkb6STjezL0maKGmMpK6Sfq/cFbwOlrvy/EdJt0XPuVfSsQ03FGV2vwZfHydpfQjhsTzGiyLh/CDv4+xekt6UNDk6uZ9mZoleFEFxVDKzZjZT0lpJD0u6KYSwOI+hV+w426DeV9L+yl2AyIYQQqv4kDRf0kENvu4nKUga0KA2QtKCpp4nabakAxs81kvSBkltm9jm/pLaK/d2nPGSXmtq2ZjnBkmHNPj6TElTon9/R9J7jZb/k6TRDb5uI2mNcqE/WdLfGzxmkhZIOrWJbZtyB/UNkmolLZG0R6V/hq3xo9y5lXSRpLsb1e6QdEnC8VYyt30kzZW0VYOxDKz0z7C1fVQgsztKWh2ts32U4XpJP8ljvGc0+PpQSW83GOd6SR0bPH69pP9utI43lTvZ
2V/Sh5KswWNTJf2iiW13lvSWpP5x3zs+qjOz0eNZPT+YEG1/tKR2kk6QtFxSt0r/HFvTRyUy22C5jpJOlDQqz/FW5DjbaB0XSXqm0j+/fD5axRW1Zryfx7J9lbtXa7mZLVcu5HWSesYtHEJ4NoSwPoSwXNJZkvor90pHS8b2rnKveDQ17r6Srm4wtk+UO+D2jp73r+VDLq2b2u9TlXt1bifl/iP5tqT/NbNtN/EclFepcrtK0paNaltKWtnCsZUzt7+V9PMQwqd5jBXlU5LMhhDeUO6tLOOVu8erm6TXlTvZbMnYGmf24xDC2kZjO3vj2KLx9Ymes62kD6KsNlxfU34m6bYQwrw8xory4fzA+0zS/BDCzSH3tse7o+Uz0ViiFShZZjcKubdB3iXpfDPbtYVjK+dxtqGTlbv3PjNa00StqfeiNqyvlrT5xi+it/11b/D4+5K+EUL4XIOPjiGED/IYwybvsWmkT4N/b6/cKwhx4944tjGNxrZZCGGqcicv/1qXmVmjdTe2q6RHQghzQgj1IYQ/R+vYJ4+xozjKndtZknaJMrLRLlE9qUrl9kBJV0RvJdr4Nt0XYt5uhNIq+7E2hHBfCGHnEEJXSRcr95/8tDzGnG9mL200ts2jE5eFkno3+v3ZXk07UNK4BpntI+kPZnZeHmNH4Tg/iCQ4zs6MWT/KLw2ZbSdpQB5jrtRxVpJkuS6V20q6L48xV1xrmqgtUvOBmiOpo5kdZmbtlHv/bIcGj9+g3Pu9+0r/alF7ZNyKzGwny7U5rzGzzpKukvSBcq9YbLzhd34z4/mxmXUxsz7KveJ2zyaWvUHST8xsp2j9W0X3PkjSo5J2MrNjLNdsYZxy7wduyjRJh5nZAMs5WLn3E7/WzHhRfGXNraRnlHtFbVz0XvCNTTmeip6b5twOVu5FhqHRhyQdIenBZsaL4ip3ZmVmu0fH2u7K3cvwSHSlbeON6s2dWH7PzLaz3I30F2jTmb1R0hmWu4fXzKxTtB9bSHpBubeLjzOztmZ2jKQ9N7GuA5W7l25jZj9U7p6MazfxHBQf5wfJj7MPSupiZqOi8X9LuStzzzczXhRXuTO7l5ntZ2btzWyz6MWknpJejB5P83F2o1GS7g8h5PMOoYprTRO1yyRdaLlLqOfELRC9ZepMSTcpd9BcrX9/+8zVyt1A+YSZrZT0d+Vu2pUkmdksMzsp+rKnciFcodxNvf0kHR5C2BA93kfNH9gekjRDuUYejyrXJj9WCOFB5bqd3W1mK5SbVH0jemyJpOOUa7m/VNKghtu23M3Cq+z/bha+VdLdyp20r5D0O+VejXujmfGi+Mqa2xDCeuVuLj5ZufsOTpF0VFSXUpzbEMLiEMJHGz+ixZaEED5rZrwornIfazcuv1y5exiWSzqtwWN9lPuPfVPulPSEcsfqdyQ1+YdTQwjTo/WPV+7m+7nK3Re08ffnmOjrZZL+Q7lW5v8SZfbL0fJLG2W2TtKyEMKqZsaL4uL8IPlx9hNJ35R0jqRPJZ0v6choPSifcme2g3IvIC2N1nWopMNCCBuviqX2OBt93VHS8crY2x6l6EY8lJ+ZPSHprBDC7CYeD5IGhRDmlndkQNPILbLGcn9I+N4QwuNNPD5fucYJfynrwIAmcJxF1nCcLZ1S/RFHNCOE8LVKjwHIF7lF1oQS/SFhoFQ4ziJrOM6WTmt66yMAAAAAZAJvfQQAAACAlOGKGgAAAACkTEH3qJnZIcp1jamRdFMI4fJNLd/eOoSO6lTIJtGKrdVqrQ/r8vk7Mw6ZRTkVI7NSfrklsyjUSi1bEkLo3vySTSOzKKdyZ1YityhM0vODFk/ULPeH866VdLBy7T6nmdnDIYTXm3pOR3XScDuwpZtEK/dimFLQ88ksyq3QzEr555bMolB/Cfe9W8jzySzKrdyZlcgtCpP0/KCQtz7uKWluCOGd6G8a3C2pyT9ICqQAmUUWkVtkDZlF1pBZpFIhE7Xekt5v8PWCqPZvzOx0
M5tuZtM3aF0BmwMKRmaRRc3mlswiZcgssobzA6RSIRO1uPdVuhaSIYQJIYRhIYRh7dShgM0BBSOzyKJmc0tmkTJkFlnD+QFSqZCJ2gJJfRp8vZ2kDwsbDlBSZBZZRG6RNWQWWUNmkUqFTNSmSRpkZv3NrL2kEyQ9XJxhASVBZpFF5BZZQ2aRNWQWqdTiro8hhFozGyvpceVamU4MIcwq2siAIiOzyCJyi6whs8gaMou0KujvqIUQHpP0WJHGApQcmUUWkVtkDZlF1pBZpFEhb30EAAAAAJQAEzUAAAAASBkmagAAAACQMkzUAAAAACBlmKgBAAAAQMowUQMAAACAlGGiBgAAAAApw0QNAAAAAFKGiRoAAAAApAwTNQAAAABIGSZqAAAAAJAyTNQAAAAAIGXaVnoAAAAASCdr197V6vbZqcXra/vSXFerX7myxesDqhlX1AAAAAAgZZioAQAAAEDKMFEDAAAAgJRhogYAAAAAKVNQMxEzmy9ppaQ6SbUhhGHFGFS1WXPMcFfb/IEXC1rnhoN2d7Wnbr3Z1epCfeJ1HvDasa7W7pdbu1rbp2YkXmcakVtkTavObJsaV/pk1J6uNu3S613ttPf3dbUPDusYu5m6JUtbMDg0pVVntoJqevZwtRWTO7val7q9n3idnWvWutp/9/DnG0n9cukQV/vrLpu1eH3FQmaRRsXo+nhACGFJEdYDlBO5RdaQWWQNmUXWkFmkCm99BAAAAICUKXSiFiQ9YWYzzOz0uAXM7HQzm25m0zdoXYGbA4pik7kls0ghMousIbPIGs5pkTqFvvVx3xDCh2bWQ9KTZvZGCOHZhguEECZImiBJW9rWocDtAcWwydySWaQQmUXWkFlkDee0SJ2CJmohhA+jz4vN7EFJe0p6dtPPqh5zb9/N1aZ/5VpX62D/cLV1V9cWtO029oKrbQjtC1rnlJ3vc7Whw7/vats9VdBmKq6157aafHaUbyqxbJA/rG17xdRyDKdkWnNmw147u9rUX4x3tQ0xp0zXbee/RXuc6I9pktTzmmxnJG1ac2bLpWbIIFe74fFbXK13zeYFbadWda5228reLV7fEx/5ZiIdNL/F6ysWMos0avFbH82sk5ltsfHfkr4m6bViDQwoBXKLrCGzyBoyi6whs0irQq6o9ZT0oJltXM+dIYQ/F2VUQOmQW2QNmUXWkFlkDZlFKrV4ohZCeEfSrkUcC1By5BZZQ2aRNWQWWUNmkVa05wcAAACAlCnGH7xuFT46ax9Xe/2Aq12tjTokWl87qyl4TEnMWu+blhw95Xuxy7ZZ5eMw6NczXI02R2iJVcfv5Wof72au1v7zK1zt17v8IXadO7b/m6v1rPG/g1/cfbSr9T9hZuw6kS7zjkzWCGHvi8e62vId/dFq8O2zY5/v2yUA6WDDfEMdSbrm/htcLWnjkB1v8+cBPafXxy7bJqZTz2YP+SZpSaWhcUhrt/6QPVztnht/62pd22zmajXmr/HUhfjsFGJx3RpX+9r4cxM/32IO6r2uyl7TKK6oAQAAAEDKMFEDAAAAgJRhogYAAAAAKcNEDQAAAABShmYiCdW387U2KZvn7vk/33e1XvfNdbXBi6YnXieNQ7Apbfv2ia2/dcZ2rvbUSVe4Ws8af6NyPn780QhXG7TZIld79cs3u9puF57lan1+kb0bjatJTbeurjbpuGtd7dE1W7lazykfulrXm951NZqGIM3eu8Q3LnvgO1fGLtuvrW8ccs3yAa726JgRrjZgakwzkHp+O7Ksfr+hrrb4nHWxy/76ixNdrUubjn6dMWeB9aE8OekWc37w0lnXJH7+hphx7jzAnyd//vzXXa1+5crE2ym1dM00AAAAAABM1AAAAAAgbZioAQAAAEDKMFEDAAAAgJRhogYAAAAAKUPXxwzY8elTfe2ny1xtm0X/dLW61atLMia0PquPHe5qA86ZHbvsH7f/Y0zVd3CaV7vW1U689BxX6z5jRex22iz42NUe+h/f
+er/ff06V9t8IT1N02bxkYNdbc8OT7jakDtHutoO814oyZiAUpl/6d6u9udv/8rVto/p7tiUP3+pp6u1WfdKfgNDJt11l++Qu1VMJ8d8zFzvOyeOvOkHBa2z2G4Y7f9/l6S9O/jam0f5ZQ+7c7Sr2fPp+Z3hihoAAAAApAwTNQAAAABIGSZqAAAAAJAyzU7UzGyimS02s9ca1LY2syfN7K3oc5fSDhPID7lF1pBZZA2ZRdaQWWRNkmYikySNl3Rrg9r5kqaEEC43s/Ojr88r/vDSo88tb7javkvGutrtP7vS1fq3LexmzvZzfROG2ndeLmidrcAkkdsWW/KIb+rw4K5XuVqvGp/Nptz86fau9seRI1yt2wzfFKKpth/+Nmdp87kDXa3mEHO1pXvVulrXm5vYUHlMEplNpL7HukoPATmT1Noz26bGlWq6d3W1t8fu4GpTRl7har1qkjcOibP/tOWu9tzJu7ta/SuvF7SdDJukjGX2g/P2cbVnxvrsxDUOmbsh/lj5rRmnuVrfMxa7Wli50tX6rJ0au85KGb3Vd2Prr580vswjKY1mr6iFEJ6V9Emj8pGSJkf/nizpqCKPCygIuUXWkFlkDZlF1pBZZE1L71HrGUJYKEnR5x7FGxJQMuQWWUNmkTVkFllDZpFaJf87amZ2uqTTJamjCrukD5QDmUXWkFlkDZlFFpFblFtLr6gtMrNekhR99m9sjYQQJoQQhoUQhrVTzF+fA8onUW7JLFKEzCJryCyyhnNapFZLr6g9LGmUpMujzw8VbUQpVbe08VuapS6TfNODp871TRhGb/VeQdsee/wjrvbY5C+5Wu28dwvaTivQ6nLbWNsB/Vyt9kbfVOPBQRNdLa5xyI8/Gh67ncfm7ORq/a+OaQkyY2bs8wvR98Y3Xa3NWP+a1JxDb3C1w+Vvuq+wVpXZHi8sdbXZGza42tl7POlqj3bxTWTqli0rzsCQj6rMbNs+28XWF329j6u9+LNrE67VX5GZ8pk/+X9+tT+vkKSfdnvV1c7rOtvVvnDvB672++O+6WqtuMFIqjP79eP/7mpJG4ec9MuzY9e53fX+/DWuOVfatO21jatdetSdiZ8/a70/32mz1v8f01QDs0pI0p7/LkkvSPq8mS0ws9HKhflgM3tL0sHR10BqkFtkDZlF1pBZZA2ZRdY0e0UthHBiEw8dWOSxAEVDbpE1ZBZZQ2aRNWQWWdPSe9QAAAAAACXCRA0AAAAAUqbk7flbm5uu8jfofuWiq1xtYLvk3YJO32q+q40fdYSrbX8JzUSqXc2gAa62Ypfusct2OtPfQH5Er2muFtfs5oJFX3a1KTfu5Wrb3Dc3dtv9P/5nbL0c3j7L33hfrz+72s63jXO1AfI3WKN86l6f42q/+vAQV7ul7xRXu3cfv1yHR33eS2H5yL1dreep82KXrT2tk6vVzXm76GNCAcxc6Y2z45uJzDnONw6pMf8a+Hu1q1ztuAt/7GrdnvvQ1eqX+GZmkjT8+O+52hHj/upqF3Z7zdVW3/O4q91xzEGx26mb5Rs0oXyeusX/37tzT9/Ia+vZvgVG9zur6/+0ud/r72pHd3o08fOPfuZMVxs8Y0ZBYyo1rqgBAAAAQMowUQMAAACAlGGiBgAAAAApw0QNAAAAAFKGZiJF1vVmf+Pmt3qc42qvjL2moO2s65aFvyGPQtQM9DfNDr/vDVe7oNu9BW1n2LRvu1rv0YtcrftSn+00ptD8/dRaU7/B1Qbd8rGrpXF/Wrtlx/vmG3rRlz45zTdr6PWnGr9gfWE/ZdttJ1cbd+EfXK2d1cY+/5YF/vlIl5odB7raE0df2cTSm7vKCfO+6morDvcHps8t88fU+NTEizvf+MejfVxt0tO+Gcl3tvRNSy4cu1XsdgZ/N49Boeh6jJ9a6SFUxOpv+YYpz5x8RcySm8U+/8wF+7vajmP9OVR93iMrL66oAQAAAEDKMFEDAAAAgJRhogYAAAAAKcNEDQAAAABShmYi
ZdBpYUx3gwKd/dXHXO2x7XZxtdoFHxR92yi+tgP6udrr53ZztYe63Zd4nS+v97fIjrnqLFfrfZe/ubZuqb/5PJX28pnf+YC3XK1zmw6u9tbo7q424Ny5xRkXiias8k1C4szY43ZX++IlY12t7099A4amrD1iT1cbd9XdrvbNTstc7RsnnRa7zpo1LyXePiqjbrY/hhz0lD92StI9I25wtRXH+FOrumWLCx9YArUf+UZQUz/1zVHimomM3vfZ2HVO3bKHq9WtWNGC0QHxFv5xiKuN3+X3rtatxjcOWVL3Wew6X/m9Pz/YenXy439acEUNAAAAAFKGiRoAAAAApAwTNQAAAABIGSZqAAAAAJAyzU7UzGyimS02s9ca1C4xsw/M7JXo49DSDhNIjswii8gtsobMImvILLImSdfHSZLGS7q1Uf03IYQriz4iJHL6VvNdbfKBh7tal8mtsuvjJKU4s+9dvI+rjTned/H84+fedrXYTo5Xxncj6/r6Wlfr8fRUV6uLfXbl1AwZ5GpvntY1dtnLDr/L1Y7u5DtW/mnNFq62wz2+a1nx+7PmZZJSnNtKqV+12tX2uPz7rjbt/Gtcbfopv3G1iw/zv39N+VnP37laB2vnaoMf+a6rff5v8d0dK5yxYpukDGU2rrvum9/t5WpbzDNX6/2/8UfK8+72P/v2i6bnP7giadu/r6s999S2fsFRvsPjT7q+HrvO516e42r/db7vatr53hcTjLDiJilDmc2yRePij7UrBvrfpUeH+mP1wHa+W3Ocr048N7a+/UR/vpNFzV5RCyE8KykjvboBMotsIrfIGjKLrCGzyJpC7lEba2Yzo8vIXZpayMxON7PpZjZ9g9YVsDmgYGQWWdRsbsksUobMIms4P0AqtXSidr2kHSQNlbRQ0lVNLRhCmBBCGBZCGNZOyS5jAiVAZpFFiXJLZpEiZBZZw/kBUqtFE7UQwqIQQl0IoV7SjZL2LO6wgOIis8gicousIbPIGjKLNEvSTMQxs14hhIXRl0dLem1Ty6P4FtR+5modl6etLUR6pCmzs8dc52obQrKf3ZiZI12t54T4G9fDhvX5DayI2mzhm3eEHfsleu66yz51tTeGXJt424vq/O/GdSeO8uOZkf7DVppyWymhttbVtrn+H642eIhv6nD7ITe42sitX4jdzk7t4/479I1Dhjxzqqt9fqxvHBI37tYgzZl987/9O9reHJHs2DJjffwx+ufvftPVPtghWcOaTov9OjvdV1hDjs926OZqt54wPmZJ3zClKaMfOd3VBt7793yGlWppzmyWdf3mgtj6tCEPxlSTXZ38/F98E5vBlzZxDpRojenX7ETNzO6SNEJSNzNbIOliSSPMbKhy34f5ksaUcIxAXsgssojcImvILLKGzCJrmp2ohRBOjCnfXIKxAEVBZpFF5BZZQ2aRNWQWWVNI10cAAAAAQAkwUQMAAACAlGlRMxFU3nVLv+xqmz3kb7BH+sxa75tdDG7XPtFz/zHsDlcb8chxscuuXFu51sHbf265qz04aFLRt/Py+npX+48pP3C1wdPjbzZGNsU16hh8pj/+/VxfcrW2vbeNXedxU3xGjuw039U6vrpZovEgfQ4YOKfFz929fU1s/aFBj/rifyVb57rgc/PDc7/ias88PjT2+QMu+6ertfvLDFc7aapvgDPnAP9uv/tXx//5sB2v8k0hSDyas2iFbyomSc+v9U2ahnVY42odzC8XPvPTlko2TisHrqgBAAAAQMowUQMAAACAlGGiBgAAAAApw0QNAAAAAFKGZiJAmf1g1Jmu9r2b7nW1jrbB1fq09U06nvmif25WzFrvb0l/cvUXXO22iV+Pff62z610tcHTaByCTWgT//pk3O/b998/1NV6Xz616ENCefz16V18ceRzBa3zyLcOc7WBW3zsaldt45vddDB/CnZd7+f9Rk6JqUn62jOn+XW+Ms/Vvr/b07HPb+zYTsti65M6+KYOQHN6HzMrtn6Z
/O/hd9+a62qHbf5p0ceURVxRAwAAAICUYaIGAAAAACnDRA0AAAAAUoaJGgAAAACkDM1EymBdF6v0EJAibf76sqtdP2hgoueGvXd1tQ9GdCp4TJWy/WP+5vX6f852tV6Kb+AQij4iVLs1X9gmtn5s5yWuduHDQ1xtB71Q9DGhPAb9/kNXu+0on4eRW3zkanev6h67zvox/vj75oIVrnZEr2Ndbc4ZPf1zT7zO1aatiz/SrTnHN5e6bMfHXW339jV+2xvWutqZY8bFbqfD/Fdi60BLvPfTfVztoM3+7mpzN9S72pCL57taXVFGlV5cUQMAAACAlGGiBgAAAAApw0QNAAAAAFKGiRoAAAAApEyzEzUz62NmT5vZbDObZWZnRfWtzexJM3sr+tyl9MMFmkdmkTVkFllEbpE1ZBZZk6TrY62ks0MIL5nZFpJmmNmTkr4jaUoI4XIzO1/S+ZLOK91Qs2H1t4a72uRxv4lZsrCGm++u2TqmurSgdVaRqs2svfBPV9suw03ofE+nVqtqM5s27x1Cs+MiylRua+e962oTLj7G1T73i9td7YTOH8eu8+C/THa14Y+f5Wo9n/G5+/J+s2LX2dgeHeI7R/9tl3tjqr7D4xkLvuxqr139RVfb8nHfeU+quu66mcpsVtQeuLurrfyR734qSaf09Z1JO1g7V3u/tqOr1S1a3ILRZVuzV9RCCAtDCC9F/14pabak3pKOlLTxCDVZ0lGlGiSQDzKLrCGzyCJyi6whs8iavO5RM7N+knaT9KKkniGEhVIu+JJ6NPGc081suplN36B1hY0WyBOZRdaQWWRRvrkls6g0jrXIgsQTNTPrLOl+ST8IIcRfz4wRQpgQQhgWQhjWTh1aMkagRcgssobMIotaklsyi0riWIusSDRRM7N2ygX6jhDCA1F5kZn1ih7vJan1vXEUqUVmkTVkFllEbpE1ZBZZ0uxd1WZmkm6WNDuE8OsGDz0saZSky6PPD5VkhGW24j/3iq2v6u3ntKv71bnaI4f91tUGt2tf0JjuWNnLj2d0XEMimolIrS+zyD4yWz6b90v84rkG3bbM1WiA83+qIbdb3OMbaIxffLyrDZv8u9jn96rZ3NXmfmOCX/Ab+Y+tOT9f4huC3Pn4/q4C1XLsAAAIcElEQVQ2+PoPXW3LefGNQ6pdNWS20tYesaerLRr5mau9OvTuxOu8f1U3V7t59JGu1kavJF5ntUjS/mpfSSMlvWpmG79DFygX5j+Y2WhJ70k6rjRDBPJGZpE1ZBZZRG6RNWQWmdLsRC2E8DdJ8b1hpQOLOxygcGQWWUNmkUXkFllDZpE1eXV9BAAAAACUHhM1AAAAAEiZJPeotSrjL42/YXiX9jUJ19DyxiEL6/zNmJL0Pw8c62r933yhxdsBgNZq9cqOiZede5Jv2jRgZjFHgzSqefolVxu9/X6xyy4Zs7errdjBL/f1A/w6k3rsxaGx9SGXve9qAz7w5wa1Ld4yWrs1xwx3td9edY2rxZ0jf1q/NnadP1rgO+t8dM4AV2vzfOtrHBKHK2oAAAAAkDJM1AAAAAAgZZioAQAAAEDKMFEDAAAAgJShmUiFbAh1rnbymB/GLtv/TzQOAYBiGHzNhtj63P3Xudq3D/mrq/3jhkGuVjvv3cIHhkzq9nv//3O3mOXeKmAbg/RibJ0mIWgp69DB1T47eFdXu+hXE10taXO9PR/8UWx90DifZxONQ5rCFTUAAAAASBkmagAAAACQMkzUAAAAACBlmKgBAAAAQMrQTKSR/7z9rNh69z0WudrTX7w30TpvWO7/4vrEaw9ztR5/mppofQCAlgnTXo2tH3Xb2a42/KBZrrauX1dXq6GZCIAUWnrq3rH1Ad+Z42qP9L++xdv5w6oerjbw7rUtXh/+D1fUAAAAACBlmKgBAAAAQMowUQMAAACAlGGiBgAAAAAp02wzETPrI+lWSdtIqpc0IYRwtZldIuk0SR9Hi14QQnisVAMtl34XvZB42cO1
e4u300M0DimV1pZZZB+Zrby4Y/+ii/xyNXqpDKNJPzKLLKrm3C4/2TcOufeiK2KX3a7tZonWuazeNwTZ745zXG3gbZ+4ms36Z6JtYNOSdH2slXR2COElM9tC0gwzezJ67DchhCtLNzygRcgssobMImvILLKI3CJTmp2ohRAWSloY/Xulmc2W1LvUAwNaiswia8gssobMIovILbImr3vUzKyfpN0kvRiVxprZTDObaGZdmnjO6WY23cymb9C6ggYL5IvMImvILLKGzCKLyC2yIPFEzcw6S7pf0g9CCCskXS9pB0lDlXt14qq454UQJoQQhoUQhrVThyIMGUiGzCJryCyyhswii8gtsiLRRM3M2ikX6DtCCA9IUghhUQihLoRQL+lGSXuWbphAfsgssobMImvILLKI3CJLknR9NEk3S5odQvh1g3qv6L2+knS0pNdKM0QgP2QWWUNmkTVkFllUzbnt8uoKV/vaPT+OXfb1k8a72gWLhrnalAl7uVr/G3yH3LokA0SLJOn6uK+kkZJeNbNXotoFkk40s6GSgqT5ksaUZIRA/sgssobMImvILLKI3CJTknR9/Jski3koU39fAq0HmUXWkFlkDZlFFpFbZE1eXR8BAAAAAKXHRA0AAAAAUibJPWoAAAAAUiq8PMvVBrwcv+zh5+4etwZX6S7fOATlxRU1AAAAAEgZJmoAAAAAkDJM1AAAAAAgZZioAQAAAEDKWAj+5sGSbczsY0nvRl92k7SkbBsvrWraFym9+9M3hNC9nBsks5mR1v0hs8VTTfsipXt/yprbKs6sVF37k+Z9qeSxNs3fl5aopv1J874kymxZJ2r/tmGz6SGEYRXZeJFV075I1bc/xVJN35dq2hep+vanWKrp+1JN+yJV3/4US7V9X6ppf6ppX4qp2r4v1bQ/1bAvvPURAAAAAFKGiRoAAAAApEwlJ2oTKrjtYqumfZGqb3+KpZq+L9W0L1L17U+xVNP3pZr2Raq+/SmWavu+VNP+VNO+FFO1fV+qaX8yvy8Vu0cNAAAAABCPtz4CAAAAQMowUQMAAACAlCn7RM3MDjGzN81srpmdX+7tF8rMJprZYjN7rUFtazN70szeij53qeQYkzKzPmb2tJnNNrNZZnZWVM/k/pQKmU0PMpsMmU0PMptclnNbTZmVyG1SWc6sVF25rdbMlnWiZmY1kq6V9A1JX5B0opl9oZxjKIJJkg5pVDtf0pQQwiBJU6Kvs6BW0tkhhCGS9pL0vejnkdX9KToymzpkthlkNnXIbAJVkNtJqp7MSuS2WVWQWam6cluVmS33FbU9Jc0NIbwTQlgv6W5JR5Z5DAUJITwr6ZNG5SMlTY7+PVnSUWUdVAuFEBaGEF6K/r1S0mxJvZXR/SkRMpsiZDYRMpsiZDaxTOe2mjIrkduEMp1ZqbpyW62ZLfdErbek9xt8vSCqZV3PEMJCKRcUST0qPJ68mVk/SbtJelFVsD9FRGZTisw2icymFJndpGrMbVX8jMltk6oxs1IV/IyrKbPlnqhZTI2/D1BhZtZZ0v2SfhBCWFHp8aQMmU0hMrtJZDaFyGyzyG0KkdtNIrMpVG2ZLfdEbYGkPg2+3k7Sh2UeQyksMrNekhR9Xlzh8SRmZu2UC/QdIYQHonJm96cEyGzKkNlmkdmUIbOJVGNuM/0zJrfNqsbMShn+GVdjZss9UZsmaZCZ9Tez9pJOkPRwmcdQCg9LGhX9e5Skhyo4lsTMzCTdLGl2COHXDR7K5P6UCJlNETKbCJlNETKbWDXmNrM/Y3KbSDVmVsroz7hqMxtCKOuHpEMlzZH0tqT/Kvf2izD+uyQtlLRBuVdTRkvqqlwnmbeiz1tXepwJ92U/5S7Tz5T0SvRxaFb3p4TfJzKbkg8ym/j7RGZT8kFm8/peZTa31ZTZaH/IbbLvU2YzG42/anJbrZm1aOcAAAAAAClR9j94DQAAAADYNCZqAAAAAJAyTNQAAAAAIGWYqAEAAABAyjBRAwAA
AICUYaIGAAAAACnDRA0AAAAAUub/A864ALB12QTfAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 1080x5400 with 5 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(15, wrong_count*15))\n",
    "for i, (img, preds, truth) in enumerate(wrong_classify):\n",
    "    img = img.reshape(28, 28).cpu().numpy()\n",
    "    plt.subplot(1, wrong_count, i+1)\n",
    "    plt.imshow(img)\n",
    "    plt.title('true:%i, pred:%i' % (truth, preds))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-03T07:12:44.915877Z",
     "start_time": "2019-11-03T07:12:44.913744Z"
    }
   },
   "source": [
    "### Save Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:40:05.941535Z",
     "start_time": "2019-11-06T14:40:05.935673Z"
    }
   },
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), 'bi-rnn_cuda.pkl')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "oldHeight": 511.99678000000006,
   "position": {
    "height": "533.984px",
    "left": "411.875px",
    "right": "20px",
    "top": "161.984px",
    "width": "760.516px"
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "varInspector_section_display": "block",
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
